{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import numpy as np\n",
    "import random\n",
    "from tqdm.auto import tqdm\n",
    "import itertools\n",
    "import os\n",
    "from copy import deepcopy\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_dicts(entities):\n",
    "    \"\"\"Assign each distinct entity a dense index in first-seen order.\n",
    "\n",
    "    Args:\n",
    "        entities: iterable of hashable entity names; repeats after the first\n",
    "            occurrence are ignored.\n",
    "\n",
    "    Returns:\n",
    "        (ind2entity, entity2ind): the distinct entities in first-seen order and\n",
    "        the inverse mapping from each entity to its position in that list.\n",
    "    \"\"\"\n",
    "    entity2ind = dict()\n",
    "    ind2entity = []\n",
    "    for entity in entities:\n",
    "        # Test membership on the dict (O(1)), not on the list (O(n)), so the\n",
    "        # whole pass stays linear in the number of entities.\n",
    "        if entity not in entity2ind:\n",
    "            entity2ind[entity] = len(ind2entity)\n",
    "            ind2entity.append(entity)\n",
    "    return ind2entity, entity2ind\n",
    "\n",
    "def choose(arr, ratio_or_count):\n",
    "    \"\"\"Return a uniformly random subset of ``arr``, sampled without replacement.\n",
    "\n",
    "    Args:\n",
    "        arr: sequence to sample from.\n",
    "        ratio_or_count: a float is read as a fraction of ``len(arr)`` (rounded\n",
    "            to the nearest count); an int is read as an absolute count. NumPy\n",
    "            scalar floats/integers are accepted as well.\n",
    "\n",
    "    Returns:\n",
    "        ``arr`` itself (not a copy) when the requested count is at least\n",
    "        ``len(arr)``; otherwise a new list with the sampled items, in the order\n",
    "        the random indices were drawn.\n",
    "\n",
    "    Raises:\n",
    "        TypeError: if ``ratio_or_count`` is neither an int nor a float.\n",
    "    \"\"\"\n",
    "    if isinstance(ratio_or_count, (float, np.floating)):\n",
    "        num = round(float(ratio_or_count) * len(arr))\n",
    "    elif isinstance(ratio_or_count, (int, np.integer)) and not isinstance(ratio_or_count, bool):\n",
    "        num = int(ratio_or_count)\n",
    "    else:\n",
    "        # An explicit raise still fires under `python -O`, where an\n",
    "        # `assert False` would be stripped and leave `num` unbound.\n",
    "        raise TypeError(\n",
    "            \"ratio_or_count must be an int or a float, got {}\".format(type(ratio_or_count).__name__)\n",
    "        )\n",
    "    if num >= len(arr):\n",
    "        return arr\n",
    "    rand_inds = np.random.choice(len(arr), num, replace=False).tolist()\n",
    "    return [arr[i] for i in rand_inds]\n",
    "    \n",
    "def split(arr, ratio_or_count):\n",
    "    \"\"\"Randomly partition ``arr`` into a sampled part and the remainder.\n",
    "\n",
    "    Args:\n",
    "        arr: sequence to partition.\n",
    "        ratio_or_count: a float is read as a fraction of ``len(arr)`` (rounded\n",
    "            to the nearest count); an int is read as an absolute count. NumPy\n",
    "            scalar floats/integers are accepted as well.\n",
    "\n",
    "    Returns:\n",
    "        ``[sampled, remainder]``: two lists that together contain every item of\n",
    "        ``arr`` exactly once, each preserving the original order of ``arr``.\n",
    "\n",
    "    Raises:\n",
    "        TypeError: if ``ratio_or_count`` is neither an int nor a float.\n",
    "        ValueError: propagated from ``np.random.choice`` when the requested\n",
    "            count exceeds ``len(arr)``.\n",
    "    \"\"\"\n",
    "    if isinstance(ratio_or_count, (float, np.floating)):\n",
    "        num = round(float(ratio_or_count) * len(arr))\n",
    "    elif isinstance(ratio_or_count, (int, np.integer)) and not isinstance(ratio_or_count, bool):\n",
    "        num = int(ratio_or_count)\n",
    "    else:\n",
    "        # An explicit raise still fires under `python -O`, where an\n",
    "        # `assert False` would be stripped and leave `num` unbound.\n",
    "        raise TypeError(\n",
    "            \"ratio_or_count must be an int or a float, got {}\".format(type(ratio_or_count).__name__)\n",
    "        )\n",
    "    train, test = [], []\n",
    "    # A set gives O(1) membership tests; scanning a list of indices here would\n",
    "    # make the loop below O(len(arr) * num).\n",
    "    rand_inds = set(np.random.choice(len(arr), num, replace=False).tolist())\n",
    "    for i in tqdm(range(len(arr))):\n",
    "        if i in rand_inds:\n",
    "            train.append(arr[i])\n",
    "        else:\n",
    "            test.append(arr[i])\n",
    "    return [train, test]\n",
    "\n",
    "def form_items(c, t):\n",
    "    \"\"\"Build one example from context tokens ``c`` and answer token ``t``.\n",
    "\n",
    "    The input text is the concatenation of the tokens in ``c``; the target\n",
    "    text is that input followed by ``t`` and the closing ``</a>`` marker.\n",
    "    \"\"\"\n",
    "    prompt = \"\".join(c)\n",
    "    return {\n",
    "        \"input_text\": prompt,\n",
    "        \"target_text\": prompt + t + \"</a>\",\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_dataset(num_entities, num_relations, out_degree=20, split_train_inferred=False,\n",
    "                  ood_ratio=0.05, iid_test_ratio=0.005):\n",
    "    \"\"\"Generate a random knowledge graph and its two-hop compositions.\n",
    "\n",
    "    Every entity gets ``out_degree`` outgoing edges, each with a distinct\n",
    "    relation and a uniformly random tail entity. A two-hop (inferred) fact\n",
    "    ``(h, r1, r2) -> t`` exists whenever ``(h, r1, b)`` and ``(b, r2, t)`` are\n",
    "    both atomic facts.\n",
    "\n",
    "    Args:\n",
    "        num_entities: number of entities, named \"<e_0>\", \"<e_1>\", ...\n",
    "        num_relations: number of relations, named \"<r_0>\", \"<r_1>\", ...\n",
    "        out_degree: outgoing edges per entity. The relations of one head are\n",
    "            drawn without replacement, so np.random.choice raises ValueError\n",
    "            when this exceeds ``num_relations``.\n",
    "        split_train_inferred: selects the return shape (see Returns).\n",
    "        ood_ratio: fraction of atomic facts held out as OOD; only used when\n",
    "            ``split_train_inferred`` is true.\n",
    "        iid_test_ratio: probability that a two-hop fact built from two ID\n",
    "            atomic facts is routed to the IID test set instead of training;\n",
    "            only used when ``split_train_inferred`` is true.\n",
    "\n",
    "    Returns:\n",
    "        When ``split_train_inferred`` is false:\n",
    "            (entities, relations, atomic_facts, inferred_facts)\n",
    "        otherwise:\n",
    "            (entities, relations, id_atomic_facts, ood_atomic_facts,\n",
    "             train_inferred_facts, test_inferred_iid, test_inferred_ood)\n",
    "        ``test_inferred_ood`` holds the compositions whose two hops are both\n",
    "        OOD; compositions mixing one OOD hop with one ID hop are discarded.\n",
    "    \"\"\"\n",
    "    entities = [\"<e_{}>\".format(i) for i in range(num_entities)]\n",
    "    ind2entity, _ = build_dicts(entities)\n",
    "\n",
    "    relations = [\"<r_{}>\".format(i) for i in range(num_relations)]\n",
    "    ind2relation, _ = build_dicts(relations)\n",
    "\n",
    "    # Maps a head entity to its list of (r, t) pairs. Pre-populated so that an\n",
    "    # entity without outgoing edges (out_degree == 0) still has an entry.\n",
    "    atomic_dict = {h: [] for h in entities}\n",
    "    atomic_facts = []\n",
    "    atomics = []\n",
    "\n",
    "    for i in tqdm(range(num_entities)):\n",
    "        # For each head entity, pick distinct outgoing relations ...\n",
    "        selected_rows = np.random.choice(num_relations, size=out_degree, replace=False).tolist()\n",
    "        for row_idx in selected_rows:\n",
    "            # ... and a uniformly random tail entity for each selected (h, r).\n",
    "            col_idx = np.random.randint(num_entities)\n",
    "            h, r, t = ind2entity[i], ind2relation[row_idx], ind2entity[col_idx]\n",
    "            atomic_facts.append(form_items([h, r], t))\n",
    "            atomics.append((h, r, t))\n",
    "            atomic_dict[h].append((r, t))\n",
    "\n",
    "    if not split_train_inferred:\n",
    "        inferred_facts = []\n",
    "        for ent in tqdm(entities):\n",
    "            for (r1, b) in atomic_dict[ent]:\n",
    "                for (r2, t) in atomic_dict[b]:\n",
    "                    inferred_facts.append(form_items([ent, r1, r2], t))\n",
    "        return entities, relations, atomic_facts, inferred_facts\n",
    "\n",
    "    # split ID/OOD: split() returns [sampled, remainder], each in the original\n",
    "    # order of `atomics`.\n",
    "    ood_list, id_list = split(atomics, round(len(atomics) * ood_ratio))\n",
    "    # The set is used for membership tests only. The fact lists are built from\n",
    "    # the ordered lists because iteration order over a set of str tuples varies\n",
    "    # between interpreter runs (hash randomization), which would make the output\n",
    "    # order irreproducible even with a fixed NumPy seed.\n",
    "    ood_set = set(ood_list)\n",
    "\n",
    "    id_atomic_facts = [form_items([h, r], t) for (h, r, t) in id_list]\n",
    "    ood_atomic_facts = [form_items([h, r], t) for (h, r, t) in ood_list]\n",
    "\n",
    "    train_inferred_facts, test_inferred_iid, test_inferred_ood = [], [], []\n",
    "    for ent in tqdm(entities):\n",
    "        for (r1, b) in atomic_dict[ent]:\n",
    "            # The first hop does not depend on the inner loop, so test it once.\n",
    "            first_hop_ood = (ent, r1, b) in ood_set\n",
    "            for (r2, t) in atomic_dict[b]:\n",
    "                second_hop_ood = (b, r2, t) in ood_set\n",
    "                if first_hop_ood or second_hop_ood:\n",
    "                    # Keep only fully-OOD compositions; mixed ones are dropped.\n",
    "                    if first_hop_ood and second_hop_ood:\n",
    "                        test_inferred_ood.append(form_items([ent, r1, r2], t))\n",
    "                    continue\n",
    "                if np.random.uniform() > iid_test_ratio:\n",
    "                    train_inferred_facts.append(form_items([ent, r1, r2], t))\n",
    "                else:\n",
    "                    test_inferred_iid.append(form_items([ent, r1, r2], t))\n",
    "\n",
    "    return entities, relations, id_atomic_facts, ood_atomic_facts, train_inferred_facts, test_inferred_iid, test_inferred_ood\n",
    "    \n",
    "NUM_ENTITY_IN = 2000\n",
    "NUM_RELATION = 200\n",
    "\n",
    "# Seed NumPy's global RNG before generating anything so that a fresh\n",
    "# Restart & Run All draws the same random graph; the sampling helpers in this\n",
    "# notebook all draw from np.random.\n",
    "SEED = 42\n",
    "np.random.seed(SEED)\n",
    "\n",
    "# NOTE: `test_inferred_facts` receives build_dataset's last return value, i.e.\n",
    "# the OOD two-hop test facts.\n",
    "(\n",
    "    train_entities,\n",
    "    train_relations,\n",
    "    id_atomic_facts,\n",
    "    ood_atomic_facts,\n",
    "    train_inferred_facts,\n",
    "    test_inferred_iid,\n",
    "    test_inferred_facts,\n",
    ") = build_dataset(NUM_ENTITY_IN, NUM_RELATION, split_train_inferred=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Vocabulary layout: entity tokens, then relation tokens, then special tokens.\n",
    "special_tokens = [\"<mask>\", \"<sep>\", \"<a>\", \"</a>\", \"<q>\", \"</q>\"]\n",
    "vocab = [*train_entities, *train_relations, *special_tokens]\n",
    "# Every token must be unique, otherwise two symbols would share one vocab entry.\n",
    "assert len(set(vocab)) == len(vocab)\n",
    "print(\"vocab size:\", len(vocab))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_size = 3000\n",
    "\n",
    "# Cap every evaluation pool at test_size examples. The comprehension calls\n",
    "# choose() left to right, so the random draws happen in the order the pools\n",
    "# are listed. Note that test_inferred_iid is rebound to its own subsample.\n",
    "sampled = [\n",
    "    choose(pool, test_size)\n",
    "    for pool in (id_atomic_facts, ood_atomic_facts, test_inferred_iid, test_inferred_facts)\n",
    "]\n",
    "id_atomic_facts_ds, ood_atomic_facts_ds, test_inferred_iid, test_inferred_facts_ds = sampled\n",
    "\n",
    "# All atomic facts: the ID ones followed by the OOD ones.\n",
    "all_atomics = [*id_atomic_facts, *ood_atomic_facts]\n",
    "len(all_atomics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write one dataset directory per phi; phi scales how many train-inferred\n",
    "# facts are kept relative to the number of ID atomic facts.\n",
    "for phi in [18.0,12.6,9.0,7.2,5.4,3.6]:\n",
    "    dataset_name = \"composition.{}.{}.{}\".format(NUM_ENTITY_IN, NUM_RELATION, phi)\n",
    "    os.makedirs(\"data/{}\".format(dataset_name), exist_ok=True)\n",
    "    # downsampling train_inferred\n",
    "    train_inferred_facts_ds = choose(train_inferred_facts, round(phi * len(id_atomic_facts)))\n",
    "\n",
    "    # Probe groups in output order, each paired with the \"type\" tag its items\n",
    "    # receive. The train_inferred probes are resampled from this phi's subset.\n",
    "    probe_groups = [\n",
    "        (\"id_atomic\", id_atomic_facts_ds),\n",
    "        (\"ood_atomic\", ood_atomic_facts_ds),\n",
    "        (\"train_inferred\", choose(train_inferred_facts_ds, test_size)),\n",
    "        (\"test_inferred_iid\", test_inferred_iid),\n",
    "        (\"test_inferred_ood\", test_inferred_facts_ds),\n",
    "    ]\n",
    "    probes = []\n",
    "    for tag, group in probe_groups:\n",
    "        for item in group:\n",
    "            # Copy first so the tag never leaks into the source lists.\n",
    "            probe = deepcopy(item)\n",
    "            probe[\"type\"] = tag\n",
    "            probes.append(probe)\n",
    "\n",
    "    # train.json: all atomic facts plus the sampled inferred facts;\n",
    "    # valid.json: the sampled OOD inferred facts; test.json: the tagged probes.\n",
    "    outputs = [\n",
    "        (\"data/{}/train.json\", all_atomics + train_inferred_facts_ds),\n",
    "        (\"data/{}/valid.json\", test_inferred_facts_ds),\n",
    "        (\"data/{}/test.json\", probes),\n",
    "        # add vocab\n",
    "        (\"data/{}/vocab.json\", vocab),\n",
    "    ]\n",
    "    for path_template, payload in outputs:\n",
    "        with open(path_template.format(dataset_name), \"w\", encoding='utf-8') as f:\n",
    "            json.dump(payload, f)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "CLM",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
